package com.ml4ai.demo;

import com.ml4ai.nn.*;
import com.ml4ai.nn.core.Toolkit;
import com.ml4ai.nn.core.Variable;
import com.ml4ai.nn.core.optimizers.Moment;
import com.ml4ai.nn.core.optimizers.NNOptimizer;
import lombok.Getter;
import lombok.SneakyThrows;
import org.deeplearning4j.datasets.iterator.impl.MnistDataSetIterator;
import org.nd4j.linalg.api.ndarray.INDArray;
import org.nd4j.linalg.dataset.api.iterator.DataSetIterator;
import org.nd4j.linalg.factory.Nd4j;

/**
 * Generative adversarial network (GAN) demo on MNIST.
 *
 * <p>A small fully-connected generator maps a 3-dimensional Gaussian noise vector to a
 * 28x28 = 784-feature image, and a fully-connected discriminator maps 784 features to a single
 * sigmoid output. The two networks are trained alternately: the discriminator maximizes
 * {@code log D(x) + log(1 - D(G(z)))}, and the generator maximizes {@code log D(G(z))}
 * (the non-saturating generator objective).
 */
public class GenerativeAdversarialNet {

    /** Number of noise features fed to the generator. */
    private static final int LATENT_DIM = 3;

    /** Number of features in a flattened 28x28 MNIST image. */
    private static final int IMAGE_SIZE = 28 * 28;

    /** Number of alternating discriminator/generator training steps. */
    private static final int ITERATIONS = 10000;

    /** Print the generator loss every this many iterations. */
    private static final int LOG_EVERY = 100;

    // Hyper-parameters passed to Moment; presumably learning rate and momentum coefficient —
    // NOTE(review): confirm against the Moment constructor.
    private static final double LEARNING_RATE = 1e-3;
    private static final double MOMENTUM = 0.95D;

    /**
     * Holds the generator and discriminator sub-networks.
     *
     * <p>Both sub-networks are registered via {@code add(...)}; the per-network parameter
     * accessors let each one be optimized by its own optimizer.
     */
    @Getter
    public static class GAN extends BaseForwardNetwork {

        /** Maps LATENT_DIM noise features to IMAGE_SIZE image features. */
        private final ForwardNetwork generator;

        /** Maps IMAGE_SIZE image features to one sigmoid-activated score. */
        private final ForwardNetwork discriminator;

        public GAN() {
            // Generator.
            // NOTE(review): there is no output activation, so generated pixels are unbounded,
            // while real MNIST features are presumably scaled to [0, 1]; consider a final
            // Sigmoid if the discriminator separates real/fake too easily.
            generator = new SequentialForward(
                    new Linear(LATENT_DIM, 128),
                    new Relu(),
                    new Linear(128, 256),
                    new Relu(),
                    new Linear(256, IMAGE_SIZE)
            );
            // Discriminator: the final Sigmoid squashes the score into (0, 1).
            discriminator = new SequentialForward(
                    new Linear(IMAGE_SIZE, 256),
                    new Relu(),
                    new Linear(256, 128),
                    new Relu(),
                    new Linear(128, 1),
                    new Sigmoid()
            );
            add(generator);
            add(discriminator);
        }

        /** Returns the parameters reported by the base network (delegates to {@code super}). */
        @Override
        public Variable[] getParameters() {
            return super.getParameters();
        }

        /** Returns only the discriminator's parameters, for the discriminator optimizer. */
        public Variable[] getDiscriminatorParameters() {
            return discriminator.getParameters();
        }

        /** Returns only the generator's parameters, for the generator optimizer. */
        public Variable[] getGeneratorParameters() {
            return generator.getParameters();
        }
    }

    /** MNIST iterator supplying real samples; initialized in the static block below. */
    private static final DataSetIterator dataSetIt;

    static {
        try {
            // NOTE(review): the (int, int) constructor is (batchSize, numExamples), so only a
            // single MNIST image is loaded here — confirm this is intended.
            dataSetIt = new MnistDataSetIterator(1, 1);
        } catch (Exception e) {
            System.out.println("数据集获取失败");
            // Nothing below can run without data; fail fast with the cause attached instead
            // of leaving the iterator null and crashing later with a bare NullPointerException.
            throw new IllegalStateException("数据集获取失败", e);
        }
    }

    /**
     * Returns the feature matrix of the next MNIST batch, rewinding the iterator when it is
     * exhausted.
     *
     * @return the features of one batch (batch size 1 as configured above)
     */
    public static INDArray takeSample() {
        // The iterator holds a finite number of examples; without a rewind, next() fails once
        // it runs out, long before the training loop finishes.
        if (!dataSetIt.hasNext()) {
            dataSetIt.reset();
        }
        return dataSetIt.next().getFeatures();
    }

    /**
     * Trains the GAN with alternating discriminator and generator updates, printing the
     * generator loss every {@code LOG_EVERY} iterations.
     */
    @SneakyThrows
    public static void main(String[] argv) {
        GAN gan = new GAN();
        NNOptimizer generatorOptimizer =
                new Moment(gan.getGeneratorParameters(), LEARNING_RATE, MOMENTUM);
        NNOptimizer discriminatorOptimizer =
                new Moment(gan.getDiscriminatorParameters(), LEARNING_RATE, MOMENTUM);

        Toolkit tool = new Toolkit();
        for (int i = 0; i < ITERATIONS; i++) {
            Variable seed = new Variable(Nd4j.randn(new int[]{1, LATENT_DIM}));
            Variable generated = gan.generator.forward(seed)[0];
            Variable real = new Variable(takeSample());

            // Discriminator step: maximize log D(x) + log(1 - D(G(z))), i.e. minimize its
            // negation.
            Variable realActive = gan.discriminator.forward(real)[0].mean();
            Variable generatedActive = gan.discriminator.forward(generated)[0].mean();
            Variable discriminatorTarget = realActive.log(Math.E)
                    .add(new Variable(1).sub(generatedActive).log(Math.E));
            Variable dLoss = new Variable(0).sub(discriminatorTarget);

            tool.grad2zero(dLoss);
            tool.backward(dLoss);
            discriminatorOptimizer.update();

            // Generator step: maximize log D(G(z)). Re-run the discriminator on the fake sample
            // so the generator's gradient is taken through the discriminator as just updated,
            // rather than through the graph built before discriminatorOptimizer.update().
            Variable generatedActiveAfterUpdate = gan.discriminator.forward(generated)[0].mean();
            Variable gLoss = new Variable(0).sub(generatedActiveAfterUpdate.log(Math.E));

            // Zeroing here also discards the generator gradients accumulated by dLoss.backward.
            tool.grad2zero(gLoss);
            tool.backward(gLoss);
            generatorOptimizer.update();

            if (i % LOG_EVERY == 0) {
                System.out.println("loss:" + gLoss.data.scalar);
            }
        }
    }


}

